{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make Stable Diffusion 3x Faster with DeepCache"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/sd_deepcache.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
    "This tutorial demonstrates how to use the ``pruna`` package to reduce the latency of any U-Net–based diffusion model with :doc:`DeepCache </compression>`.\n",
    "We use the ``stable-diffusion-v1-4`` model as an example, although the tutorial also applies to other popular diffusion models, such as ``SD-XL``.\n",
    "To accelerate transformer-based diffusion models, check out the ``pruna_pro`` tutorial :doc:`\"Make Any Diffusion Model 3x Faster with Auto Caching\" </docs_pruna_pro/tutorials/sd_auto_caching>`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you are not running the latest version of this tutorial, install the matching version of pruna instead.\n",
    "# The following command installs the latest version of pruna.\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Loading the Stable Diffusion Model\n",
    "\n",
    "First, load the Stable Diffusion pipeline in half precision (`torch.float16`) and move it to the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "# Define the model ID\n",
    "model_id = \"CompVis/stable-diffusion-v1-4\"\n",
    "\n",
    "# Load the pre-trained model\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)\n",
    "pipe = pipe.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Initializing the Smash Config"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext"
   },
   "source": [
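    "Next, initialize the smash config. In this example, we use :doc:`DeepCache </compression>`, which caches high-level U-Net features and reuses them across adjacent denoising steps instead of recomputing them at every step."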
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig\n",
    "\n",
    "# Initialize the SmashConfig\n",
    "smash_config = SmashConfig([\"deepcache\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Smashing the Model\n",
    "\n",
    "Now, smash the model. This only takes a few seconds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import smash\n",
    "\n",
    "# Smash the model\n",
    "smashed_model = smash(\n",
    "    model=pipe,\n",
    "    smash_config=smash_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Running the Model\n",
    "\n",
    "Finally, run the smashed model to generate an image with accelerated inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the prompt\n",
    "prompt = \"a fruit basket\"\n",
    "\n",
    "# Generate an image and display it\n",
    "smashed_model(prompt).images[0]"
   ]
  },
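  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The actual speedup depends on your GPU, scheduler, and number of inference steps, so it is worth measuring on your own setup. The cell below is a minimal sketch of one way to do that and is not part of the `pruna` API: `mean_latency` is a hypothetical helper introduced here for illustration, and its `warmup` and `repeats` defaults are arbitrary assumptions. It reuses `smashed_model` and `prompt` from the cells above. For a baseline, you could call the same helper on `pipe` right after step 1, before smashing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def mean_latency(run, warmup=1, repeats=3):\n",
    "    # Hypothetical helper: mean wall-clock seconds per call of `run`\n",
    "    for _ in range(warmup):\n",
    "        run()  # warm-up calls are excluded from the timing\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.synchronize()  # wait for pending GPU work before starting the clock\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(repeats):\n",
    "        run()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.synchronize()  # wait for the timed GPU work to finish\n",
    "    return (time.perf_counter() - start) / repeats\n",
    "\n",
    "\n",
    "# Time the smashed pipeline on the prompt defined above\n",
    "latency = mean_latency(lambda: smashed_model(prompt))\n",
    "print(f\"Smashed model: {latency:.2f} s per image\")"
   ]
  },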
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Wrap Up"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "raw_mimetype": "text/restructuredtext",
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Congratulations! You have successfully smashed a Stable Diffusion model! You can now use the ``pruna`` package to optimize any U-Net–based diffusion model.\n",
    "To fit your use case, you only need to modify steps 1 and 4. Is the image quality not good enough? Or do you want to use caching with diffusion transformers such as ``FLUX`` or ``Hunyuan Video``?\n",
    "Then check out the ``pruna_pro`` tutorial :doc:`\"Make Any Diffusion Model 3x Faster with Auto Caching\" </docs_pruna_pro/tutorials/sd_auto_caching>` to take your optimization one step further."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "prunatree",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
